{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e6b84a1f",
   "metadata": {},
   "source": [
    "# Supplementary Notebook for \"Training Your Own Small Language Model from Scratch\"\n",
    "\n",
    "This iPython notebook is a supplementary notebook, which is required to complete the steps in the [the main notebok](slm_pretraining_sft.ipynb). Please refer to [slm_pretraining_sft.ipynb](slm_pretraining_sft.ipynb) for details about the environment setting including Docker container.\n",
    "\n",
    "\n",
    "## Step 4: Launch a Megatron GPT Server\n",
    "\n",
    "During Step 4 in [slm_pretraining_sft.ipynb](slm_pretraining_sft.ipynb), you need to launch a text generation server (using [megatron_gpt_eval.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_eval.py).) to get generations from the SLM that you trained. \n",
    "\n",
    "Run the following command to launch a Megatron GPT server. Make sure to use the correct .nemo filepath."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b86e1f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 4\n",
    "MODEL_FILE = \"results/megatron_gpt/checkpoints/megatron_gpt.nemo\"\n",
    "\n",
    "!python /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_eval.py \\\n",
    "  gpt_model_file={MODEL_FILE} \\\n",
    "  server=True \\\n",
    "  port=55555"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51576917",
   "metadata": {},
   "source": [
    "After launching a server, you can get generations by using Python code like below:\n",
    "\n",
    "```\n",
    "import json\n",
    "import requests\n",
    "\n",
    "class MegatronGPTEvalClient:\n",
    "    def __init__(self,\n",
    "                 batch_size: int = 1,\n",
    "                 port_num: int = 55555,\n",
    "                 headers = {\"Content-Type\": \"application/json\"}):\n",
    "        self.batch_size = batch_size\n",
    "        self.port_num = port_num\n",
    "        self.headers = headers\n",
    "\n",
    "    def generate(self, text: str):\n",
    "        data = {\"sentences\": [text] * self.batch_size,\n",
    "                \"tokens_to_generate\": 32,\n",
    "                \"temperature\": 1.0,\n",
    "                \"add_BOS\": True,\n",
    "                \"top_k\": 0,\n",
    "                \"top_p\": 0.9,\n",
    "                \"greedy\": False,\n",
    "                \"all_probs\": False,\n",
    "                \"repetition_penalty\": 1.2,\n",
    "                \"min_tokens_to_generate\": 2}\n",
    "        resp = requests.put('http://localhost:{}/generate'.format(self.port_num),\n",
    "                            data=json.dumps(data),\n",
    "                            headers=self.headers)\n",
    "        sentences = resp.json()['sentences']\n",
    "        generation = sentences[0]\n",
    "        return generation[len(text):].lstrip()\n",
    "\n",
    "\n",
    "client = MegatronGPTEvalClient()\n",
    "client.generate(\"How are you?\")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55418d2d",
   "metadata": {},
   "source": [
    "## Step 5: Launch a Megatron GPT Server after Fine-Tuning (Advanced)\n",
    "\n",
    "\n",
    "\n",
    "### Step 5-1: .nemo File Conversion\n",
    "\n",
    "This conversion step is required as NeMo-Aligner's script creates a slightly different .nemo checkpoint, which is not compatible with [megatron_gpt_eval.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_eval.py).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee6f0682",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confirm if the .nemo file exists\n",
    "!ls results/megatron_gpt_sft/checkpoints/megatron_gpt_sft.nemo"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad39ce5a",
   "metadata": {},
   "source": [
    "Run the following commands to convert the .nemo into a new .nemo checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d63f8119",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Untar the .nemo checkpoint\n",
    "!cd results/megatron_gpt_sft/checkpoints && mkdir -p megatron_gpt_sft_fixed && tar -xf megatron_gpt_sft.nemo -C megatron_gpt_sft_fixed\n",
    "\n",
    "# Update the target value\n",
    "!cd results/megatron_gpt_sft/checkpoints/megatron_gpt_sft_fixed && sed -i 's|target: nemo_aligner.models.nlp.gpt.gpt_sft_model.GPTSFTModel|target: nemo.collections.nlp.models.language_modeling.megatron_gpt_model.MegatronGPTModel|' model_config.yaml\n",
    "\n",
    "# Tar the files to create a new .nemo\n",
    "!cd results/megatron_gpt_sft/checkpoints/megatron_gpt_sft_fixed && tar -cf ../megatron_gpt_sft_fixed.nemo ./*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2209f0b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confirm if the new .nemo file exists\n",
    "!ls results/megatron_gpt_sft/checkpoints/megatron_gpt_sft_fixed.nemo"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd0c27c8",
   "metadata": {},
   "source": [
    "### Step 5-2: Launch a Megatron GPT Server\n",
    "\n",
    "The newly created .nemo checkpoint is compatible with [megatron_gpt_eval.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_eval.py) and we can launch a text generation server in the same manner as the .nemo checkpoint for pre-training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b208968f",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_FILE = \"results/megatron_gpt_sft/checkpoints/megatron_gpt_sft_fixed.nemo\"\n",
    "\n",
    "!python /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_eval.py \\\n",
    "  gpt_model_file={MODEL_FILE} \\\n",
    "  server=True \\\n",
    "  port=55555"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
